# import libraries
from pyspark.sql import SparkSession
from pyspark import SparkConf
from pyspark.sql.types import *

from pyspark.sql.functions import col, count, when

from pyspark.ml.classification import LinearSVC

import pandas as pd

#################################################
# spark config
#################################################
mtaMaster = "spark://192.168.0.182:7077"

# Build the application configuration: standalone master plus resource,
# serialization, event-log and driver-result settings (applied in this order).
conf = SparkConf()
conf.setMaster(mtaMaster)
conf.setAll([
    # resources
    ("spark.executor.memory", "24g"),
    ("spark.driver.memory", "26g"),
    ("spark.cores.max", 96),
    ("spark.driver.cores", 8),
    # Kryo serialization
    ("spark.serializer", "org.apache.spark.serializer.KryoSerializer"),
    ("spark.kryoserializer.buffer", "256m"),
    ("spark.kryoserializer.buffer.max", "256m"),
    # default partition count
    ("spark.default.parallelism", 24),
    # event log on HDFS
    # NOTE(review): spark.history.fs.logDirectory is presumably read by the
    # history server rather than by this app; kept pointing at the same dir.
    ("spark.eventLog.enabled", "true"),
    ("spark.eventLog.dir", "hdfs://192.168.0.182:9000/eventlog"),
    ("spark.history.fs.logDirectory", "hdfs://192.168.0.182:9000/eventlog"),
    # cap on data collected back to the driver
    ("spark.driver.maxResultSize", "2g"),
])

# Result is not used here (only visible when run interactively).
conf.getAll()

#################################################
# create spark session
#################################################
# Reuse an existing session if one is active, otherwise start one with `conf`.
spark = (
    SparkSession.builder
    .appName('ML1_wf_on_NYT_x10_sim1_and_sim2_to_sim3_round1')
    .config(conf=conf)
    .getOrCreate()
)

sc = spark.sparkContext

# sanity check: show the context and its default parallelism
print(sc)
print(sc.defaultParallelism)
print("SPARK CONTEXT IS RUNNING")


#################################################
# define major topic codes
#################################################

# Major topic codes iterated over by the main loop: 1-10, 12-21 and 100.
# Code 11 is not in the list, and code 23 is left out because it does not
# occur in the NYT corpus (the full list would add 23 just before 100).
majortopic_codes = [*range(1, 11), *range(12, 22), 100]


#################################################
# loop starts here
#################################################
# Ten rounds (h). In each round:
#   * for every topic code i, label rows of topic i as 1.0 and all others 0.0,
#     train on sim != 3 and test on sim == 3;
#   * if label-1 training rows are under 10% of label-0 rows, down-sample the
#     label-0 rows (see _non_label_sample_size);
#   * fit a LinearSVC ten times (j) and write doc_id + prediction * i to parquet;
#   * log the sample sizes to CSV, then for every j inner-join all topics'
#     predictions on doc_id and write the wide table to a local CSV.

# Feature table read at the start of every round.
NYT_FEATURES_PATH = "hdfs://192.168.0.182:9000/input/Data_NYT_clean_SPARK_START_ML2_features.parquet"
# Per-topic / per-repetition prediction files: written in the SVM section and
# read back in the parquet-to-pandas section of the same round.
PREDICTION_PATH_TEMPLATE = "hdfs://192.168.0.182:9000/input/NYT_prediction_mtc{i}_{j}.parquet"


def _non_label_sample_size(train_count, label_count, non_label_count):
    """Return how many label-0 training rows to sample, or None to keep them all.

    Down-sampling applies only when label-1 rows are fewer than 10% of the
    label-0 rows. The target size is then the larger of ten times the label-1
    count and one tenth (floored) of the whole training set.

    A zero ``non_label_count`` leaves nothing to down-sample, so None is
    returned instead of dividing by zero.
    """
    if non_label_count == 0 or label_count / non_label_count >= 0.1:
        return None
    return max(label_count * 10, train_count // 10)


for h in range(10):
    # Re-read the features every round. No seed is passed to
    # DataFrame.sample below, so presumably each round draws a different
    # label-0 subsample.
    df_original = spark.read.parquet(NYT_FEATURES_PATH).repartition(50)

    # check loaded data (printSchema()/show() print directly and return None)
    df_original.printSchema()
    df_original.show()
    df_original.groupBy("majortopic").count().show(30, False)

    #################################################
    # prepare to log sample numbers
    #################################################

    # One row per topic code: label-1 training rows, all label-0 training rows,
    # label-0 rows actually used, and total rows used for training.
    columns = ["label", "non_label_all", "non_label_sample", "train_all"]

    df_numbers = pd.DataFrame(index=majortopic_codes, columns=columns)

    for i in majortopic_codes:
        #################################################
        # prepare df for svm requirements
        #################################################
        print("majortopic is:", i)

        # One-vs-rest target for topic i; LinearSVC needs a double label.
        # Derived from the unmodified df_original each time (instead of
        # overwriting it) so the query plan does not grow with every topic.
        df_labeled = df_original.withColumn(
            "label",
            when(col("majortopic") == i, 1).otherwise(0).cast(DoubleType()),
        )

        #################################################
        # separate training and test sets
        #################################################

        df_train = df_labeled.where(col('sim') != 3)
        df_test = df_labeled.where(col('sim') == 3)

        # make training data proportional with regards to label occurrence frequency
        df_train_mtc = df_train.where(col('label') == 1)
        df_train_non_mtc = df_train.where(col('label') == 0)

        df_train_count = df_train.count()
        df_train_mtc_count = df_train_mtc.count()
        df_train_non_mtc_count = df_train_non_mtc.count()
        print("Rows in training DataFrame with label = ", df_train_mtc_count)
        print("Rows in training DataFrame without label = ", df_train_non_mtc_count)

        sample_num = _non_label_sample_size(
            df_train_count, df_train_mtc_count, df_train_non_mtc_count
        )
        if sample_num is not None:
            print("sample_num = ", sample_num)
            print("df_train_non_mtc = ", df_train_non_mtc_count)
            sampling_fraction = sample_num / df_train_non_mtc_count
            print("sampling_fraction = ", sampling_fraction)
            df_train_non_mtc = df_train_non_mtc.sample(False, sampling_fraction)
            df_train_non_mtc_sample = df_train_non_mtc.count()
            print("Rows in training DataFrame without label = ", df_train_non_mtc_sample)
            df_train = df_train_mtc.union(df_train_non_mtc)
            non_label_used = df_train_non_mtc_sample
            train_used = df_train_mtc_count + df_train_non_mtc_sample
        else:
            non_label_used = df_train_non_mtc_count
            train_used = df_train_count

        # numbers to logtable. Write through .loc[row, col]: a chained
        # df_numbers[col].loc[row] = ... is not guaranteed to reach df_numbers
        # (it never does under pandas Copy-on-Write).
        df_numbers.loc[i, "label"] = df_train_mtc_count
        df_numbers.loc[i, "non_label_all"] = df_train_non_mtc_count
        df_numbers.loc[i, "non_label_sample"] = non_label_used
        df_numbers.loc[i, "train_all"] = train_used
        print(df_numbers)

        # The (possibly down-sampled) training set is reused by all ten fits
        # below; cache it so it is not re-read from HDFS for every fit. The
        # count() right after materializes the cache.
        df_train = df_train.cache()

        print("Rows in training DataFrame = ", df_train.count())
        print("Rows in test DataFrame = ", df_test.count())

        #################################################
        # SVM
        #################################################

        # define svm (identical settings for every repetition)
        lsvc = LinearSVC(featuresCol='features', labelCol='label', maxIter=10, regParam=0.1)

        # NOTE(review): df_train and the estimator do not change across j, so
        # the ten prediction sets are presumably identical -- confirm the
        # repetition is intended.
        for j in range(10):
            # train the model.
            lsvcModel = lsvc.fit(df_train)

            print("fit model finished, starting scoring:", j)

            # score the model on test data.
            predictions = lsvcModel.transform(df_test)

            predictions.printSchema()
            predictions.show()

            # Keep doc_id plus the prediction recoded as an int scaled by the
            # topic code (i for a positive 1.0 prediction, 0 otherwise), stored
            # in column prediction_<i>.
            df_write = predictions.select(
                "doc_id",
                (col("prediction").cast(IntegerType()) * i).alias(
                    "prediction_{i}".format(i=i)
                ),
            )

            # write partial result to parquet
            dest_name = PREDICTION_PATH_TEMPLATE.format(i=i, j=j)
            df_write.write.parquet(dest_name, mode="overwrite")

        # Release the cached training set before moving to the next topic.
        df_train.unpersist()

        print("DONE")

    print("ALL SVM DONE round1_{h}".format(h=h+1))

    # NOTE(review): index=False drops the topic codes; rows are written in
    # majortopic_codes order.
    df_numbers.to_csv("NYT_round1_sample{h}_sample_numbers.csv".format(h=h+1), index=False)

    # empty memory
    spark.catalog.clearCache()
    print("cache cleared")

    #######################################################
    ### parquet to pandas
    #######################################################

    # For each repetition j, inner-join every topic's prediction table on
    # doc_id and write the resulting wide table to a local CSV.
    for j in range(10):
        df_results = None
        for i in majortopic_codes:
            source_name = PREDICTION_PATH_TEMPLATE.format(i=i, j=j)
            df = spark.read.parquet(source_name).repartition(50)
            # Seed with whichever topic comes first instead of assuming topic 1.
            if df_results is None:
                df_results = df
            else:
                df_results = df_results.join(df, 'doc_id', 'inner')

        # convert prediction results to pandas df and save
        df_results.toPandas().to_csv(
            "NYT_round1_sample{h}_svm{j}.csv".format(h=h+1, j=j), index=False
        )


# Shut down Spark. SparkSession.stop() also stops the SparkContext behind `sc`,
# so a separate sc.stop() call is redundant.
spark.stop()
